import matplotlib.pyplot as plt
import pandas as p
import numpy as np
from glob import glob
import seaborn as sns

def getFiles(data_dir='/Users/connorforsythe/Library/CloudStorage/Box-Box/CMU/Data/NHTS Data - 2024'):
    """Recursively collect the data files found under ``data_dir``.

    Files are matched by extension, in this order: ``.csv``, ``.CSV``,
    ``.asc``, ``.txt``.

    Args:
        data_dir: Root directory to search. It is interpolated into a glob
            pattern, so it should not contain glob wildcard characters.
            Defaults to the original hard-coded NHTS data location.

    Returns:
        A list of matching path strings (empty when nothing matches).
    """
    files = []
    for extension in ('csv', 'CSV', 'asc', 'txt'):
        files.extend(glob(f'{data_dir}/**/*.{extension}', recursive=True))

    # Where glob matching is case-insensitive, the '*.csv' and '*.CSV'
    # patterns hit the same files; keep only the first occurrence of each.
    return list(dict.fromkeys(files))

def getFileYear(file):
    """Extract the year encoded in a data file's directory path.

    The path is split on ``'/'``. The third component from the end is tried
    first; if it is not an integer, the fourth from the end is used instead.

    Raises:
        ValueError: if the fallback component is not an integer either.
    """
    parts = file.split('/')

    try:
        return int(parts[-3])
    except ValueError:
        # One extra directory level sits between the year folder and the file.
        return int(parts[-4])

def getFilesForYear(year=2017):
    """Return the files from ``getFiles()`` whose path-derived year equals ``year``."""
    return [path for path in getFiles() if getFileYear(path) == year]

def getHouseholdDataByYear(year=2017, small=False):
    """Load the household file for ``year`` and attach safety-inspection flags.

    The household file is the last path returned by ``getFilesForYear(year)``
    whose lower-cased path contains ``'hh'``.

    Args:
        year: Survey year to load.
        small: When truthy, read only the first 10000 rows.

    Returns:
        The household DataFrame merged with the state safety-inspection table
        for the same ``year`` (``HHSTATE`` matched against ``State Code``),
        plus a ``file`` column holding the source path. The merge is pandas'
        default inner join, so households whose state has no entry for that
        year are dropped. Returns ``None`` (after printing a message) when no
        household file is found.
    """
    file = None

    for tempFile in getFilesForYear(year):
        if 'hh' in tempFile.lower():
            file = tempFile  # keep overwriting: the last match wins

    if file is None:
        print(f'Household file not found for {year}.')
        return None

    data = p.read_csv(file, nrows=10000 if small else None)

    # Use the designations for the requested year; previously the lookup was
    # called without arguments and therefore always returned the 2017 table.
    safety_inspection_designation = get_state_safety_inspection_df_by_year(year)
    data = data.merge(safety_inspection_designation, left_on='HHSTATE', right_on='State Code')
    data['file'] = file
    return data

def getVehicleDataByYear(year=2017, small=False):
    """Load the vehicle file for ``year``, keeping rows with VEHTYPE 1 through 7.

    The vehicle file is the last path from ``getFilesForYear(year)`` whose
    lower-cased path contains ``'veh'``. When ``small`` is truthy only the
    first 10000 rows are read. A ``file`` column records the source path.

    For 2009 data lacking a ``VEHOWNED`` column, one is derived from
    ``VEHOWNMO``: 1 where it exceeds 11, 2 where it lies in 0..11, and the
    raw ``VEHOWNMO`` value otherwise.

    Returns ``None`` (after printing a message) when no vehicle file exists.
    """
    matches = [path for path in getFilesForYear(year) if 'veh' in path.lower()]
    if not matches:
        print(f'Vehicle file not found for {year}.')
        return None

    source = matches[-1]  # the last matching path wins
    read_options = {'nrows': 10000} if small else {}
    data = p.read_csv(source, **read_options)
    data.loc[:, 'file'] = source

    wanted_types = data.loc[:, 'VEHTYPE'].isin([1,2,3,4,5,6,7])
    data = data.loc[wanted_types, :]

    if year == 2009 and 'VEHOWNED' not in data.columns:
        months = data.loc[:, 'VEHOWNMO']
        over_eleven = months > 11
        zero_to_eleven = (months <= 11) & (months >= 0)
        # Start from the raw values so anything outside both masks is kept.
        data.loc[:, 'VEHOWNED'] = months
        data.loc[over_eleven, 'VEHOWNED'] = 1
        data.loc[zero_to_eleven, 'VEHOWNED'] = 2

    return data

def getPersonDataByYear(year=2017, small=False):
    """Load the person file for ``year`` as a DataFrame.

    The person file is the last path from ``getFilesForYear(year)`` whose
    lower-cased path contains ``'per'``. When ``small`` is truthy only the
    first 10000 rows are read. Returns ``None`` (after printing a message)
    when no such file exists.
    """
    matches = [path for path in getFilesForYear(year) if 'per' in path.lower()]
    if not matches:
        print(f'Person file not found for {year}.')
        return None

    source = matches[-1]  # the last matching path wins
    read_options = {'nrows': 10000} if small else {}
    return p.read_csv(source, **read_options)
def getHouseholdVehicleStats(vehData):
    """Aggregate vehicle-level rows into one row of statistics per household.

    Args:
        vehData: DataFrame with ``HOUSEID``, ``VEHAGE`` and ``BESTMILE`` columns.

    Returns:
        DataFrame indexed by household id (unnamed index) with columns
        ``minVehAge``, ``meanVehAge``, ``maxVehAge`` (over ``VEHAGE``),
        ``NumVehicles`` (row count per household), ``MilesDriven`` (sum of
        ``BESTMILE``) and ``HOUSEID`` (copy of the index). Rows are ordered
        by household id.
    """
    # One grouped pass instead of re-filtering the whole table per household.
    r = vehData.groupby('HOUSEID').agg(
        minVehAge=('VEHAGE', 'min'),
        meanVehAge=('VEHAGE', 'mean'),
        maxVehAge=('VEHAGE', 'max'),
        NumVehicles=('VEHAGE', 'size'),
        MilesDriven=('BESTMILE', 'sum'),
    )

    # Drop the index name: an index level and a column both called HOUSEID
    # would make a later merge(on='HOUSEID') ambiguous.
    r.index.name = None
    r['HOUSEID'] = r.index

    return r

def getHouseholdDataWithVehicleInfoByYear(year=2017):
    """Return household data for ``year`` joined to per-household vehicle statistics.

    Household rows are merged on ``HOUSEID`` with the output of
    ``getHouseholdVehicleStats`` (pandas' default inner join).
    """
    households = getHouseholdDataByYear(year)
    vehicle_stats = getHouseholdVehicleStats(getVehicleDataByYear(year))
    return households.merge(vehicle_stats, on='HOUSEID')

def summarizeDataByNumVehicles(data, cat=False):
    """Summarize households by vehicle count and safety-inspection status.

    NOTE: ``data`` is modified in place; ``Count``, ``<col>-Weighted`` and
    ``NumVehiclesCat`` columns are added to it.

    Args:
        data: Household DataFrame with ``WTHHFIN``, ``NumVehicles``,
            ``HasSafety``, ``minVehAge``, ``meanVehAge``, ``maxVehAge`` and
            ``MilesDriven`` columns.
        cat: When truthy, group by ``NumVehiclesCat`` (counts of 5 or more
            collapsed into the label ``'5+'``) instead of ``NumVehicles``.

    Returns:
        One row per (vehicle count, ``HasSafety``) group. Every other column
        is summed, then the four statistic columns are replaced by their
        ``WTHHFIN``-weighted means. ``Count`` is the number of rows in the
        group and ``Safety Inspection`` is a readable label for ``HasSafety``.
    """
    weightCol = 'WTHHFIN'
    relCols = ['minVehAge', 'meanVehAge', 'maxVehAge', 'MilesDriven']
    maxVehicles = 5

    data['Count'] = 1

    for relCol in relCols:
        data[f'{relCol}-Weighted'] = data[relCol]*data[weightCol]

    # Cast to object before inserting the '5+' label: writing a string into
    # an integer column relies on implicit upcasting, which newer pandas
    # versions warn about and later reject.
    data['NumVehiclesCat'] = data['NumVehicles'].astype(object)
    data.loc[data['NumVehicles']>=maxVehicles, 'NumVehiclesCat'] = f'{maxVehicles}+'

    groupCol = 'NumVehiclesCat' if cat else 'NumVehicles'
    grouped_data = data.groupby([groupCol, 'HasSafety'], as_index=False).sum()

    for relCol in relCols:
        # Replace the column outright; the summed column may be integer-typed
        # while the weighted mean is a float.
        grouped_data[relCol] = grouped_data[f'{relCol}-Weighted']/grouped_data[weightCol]

    grouped_data['Safety Inspection'] = 'Has Safety Inspections'
    grouped_data.loc[grouped_data['HasSafety']==False, 'Safety Inspection'] = 'No Safety Inspections'

    return grouped_data

def load_stata_data():
    """Read the hard-coded FHWACleanV4.csv file and return it as a DataFrame."""
    return p.read_csv('/Users/connorforsythe/Library/CloudStorage/Box-Box/CMU/Marginal Mileage Project/Publication Replication Code/NBER/PrimaryModel/CleanedData/FHWACleanV4.csv')

def summarizeDataByHouseholdSize(data):
    """Summarize households by ``HHSIZE`` using ``WTHHFIN``-weighted means.

    NOTE: ``data`` is modified in place; ``Count`` and ``<col>-Weighted``
    columns are added to it.

    Args:
        data: Household DataFrame with ``HHSIZE``, ``WTHHFIN``, ``minVehAge``,
            ``meanVehAge``, ``maxVehAge`` and a vehicle-count column named
            ``NumVehicles`` (the name ``getHouseholdVehicleStats`` produces)
            or, as a fallback, ``numVehicles``.

    Returns:
        One row per ``HHSIZE``. Every other column is summed, then the vehicle
        age and vehicle-count columns are replaced by their weighted means.
        ``Count`` is the number of rows in the group.
    """
    weightCol = 'WTHHFIN'

    # The original lookup used 'numVehicles', which no function in this file
    # produces; prefer 'NumVehicles' and keep the old spelling as a fallback.
    numVehiclesCol = 'NumVehicles' if 'NumVehicles' in data.columns else 'numVehicles'
    relCols = ['minVehAge', 'meanVehAge', 'maxVehAge', numVehiclesCol]

    data['Count'] = 1

    for relCol in relCols:
        data[f'{relCol}-Weighted'] = data[relCol]*data[weightCol]

    grouped_data = data.groupby('HHSIZE', as_index=False).sum()

    for relCol in relCols:
        # Replace the column outright; the summed column may be integer-typed
        # while the weighted mean is a float.
        grouped_data[relCol] = grouped_data[f'{relCol}-Weighted']/grouped_data[weightCol]

    return grouped_data

def get_state_safety_inspection_dict():
    """Map ``(Year, State Code)`` to ``HasSafety`` from the ``load_stata_data()`` table.

    Returns:
        dict keyed by ``(Year, State Code)`` tuples. If a pair occurs more
        than once in the table, the last row's ``HasSafety`` value is kept.
    """
    stata_data = load_stata_data()

    # Zip the columns directly rather than using iterrows(), which builds a
    # Series per row and coerces each row's values to a single dtype.
    keys = zip(stata_data['Year'], stata_data['State Code'])
    return dict(zip(keys, stata_data['HasSafety']))

def get_state_safety_inspection_df_by_year(year=2017):
    """Return the per-state safety-inspection flags for a single year.

    Args:
        year: Year to select from ``get_state_safety_inspection_dict()``.

    Returns:
        DataFrame indexed by state code with columns ``HasSafety`` and
        ``State Code`` (a copy of the index). When ``year`` has no entries the
        frame is empty but still carries both columns.
    """
    state_safety_inspection = get_state_safety_inspection_dict()

    # Compare against the year element only; `year in key` would also match
    # a state code that happened to equal the year.
    has_safety_by_state = {
        state_code: has_safety
        for (key_year, state_code), has_safety in state_safety_inspection.items()
        if key_year == year
    }

    r = p.DataFrame(
        {'HasSafety': list(has_safety_by_state.values())},
        index=list(has_safety_by_state.keys()),
    )
    r['State Code'] = r.index
    return r

if __name__=='__main__':
    # Script entry point: summarize one survey year and write the plots to
    # the Plots/ directory (relative to the working directory).
    import os

    files = getFiles()
    files_by_year = {}
    for file in files:
        files_by_year.setdefault(getFileYear(file), []).append(file)

    safety_inspection = get_state_safety_inspection_dict()
    safety_inspection_2017 = get_state_safety_inspection_df_by_year()

    # hhVehData2022 = getHouseholdDataWithVehicleInfoByYear(2022)
    # sumData2022 = summarizeDataByNumVehicles(hhVehData2022)
    #
    year = 2017
    # year = 2009
    hhVehData2017 = getHouseholdDataWithVehicleInfoByYear(year)
    sumData2017 = summarizeDataByNumVehicles(hhVehData2017)
    sumDataCat2017 = summarizeDataByNumVehicles(hhVehData2017, cat=True)

    # savefig does not create missing directories.
    os.makedirs('Plots', exist_ok=True)

    # Named plot_vars rather than vars to avoid shadowing the builtin.
    plot_vars = ['minVehAge', 'meanVehAge', 'maxVehAge']
    clean_vars = ['Min. Household Vehicle Age', 'Mean. Household Vehicle Age', 'Max. Household Vehicle Age']

    # Figure width, with the height chosen by the golden ratio.
    x = 5
    gr = (1+np.sqrt(5))/2
    y = x/gr

    for var, clean_var in zip(plot_vars, clean_vars):
        fig, ax = plt.subplots(1,1,figsize=(x,y))
        sns.lineplot(data=sumData2017, x='NumVehicles', y=var, hue='Safety Inspection', ax=ax)
        ax.set_ylabel(clean_var)
        ax.set_xlabel('Number of Household Vehicle')
        fig.savefig(f'Plots/{var}-{year}.png', bbox_inches='tight', dpi=300)
        plt.close(fig)

    # Normalize the household weights within each safety-inspection group.
    # Only WTHHFIN is summed: NumVehiclesCat mixes integers with the '5+'
    # label and cannot be added up.
    sumSumData2017 = sumDataCat2017.groupby('HasSafety', as_index=False)['WTHHFIN'].sum()
    sumSumData2017 = sumSumData2017.rename({'WTHHFIN':'Sum WTHHFIN'}, axis=1)
    sumDataCat2017 = sumDataCat2017.merge(sumSumData2017, on='HasSafety')
    sumDataCat2017['Norm Weight'] = sumDataCat2017['WTHHFIN']/sumDataCat2017['Sum WTHHFIN']

    fig, ax = plt.subplots(1,1,figsize=(x,y))
    sns.barplot(data=sumDataCat2017, x='NumVehiclesCat', y='Norm Weight', hue='Safety Inspection', ax=ax)
    ax.set_ylabel('Empirical Probability Mass')
    ax.set_xlabel('Number of Household Vehicle')
    fig.savefig(f'Plots/NumVehicleHist-{year}.png', bbox_inches='tight', dpi=300)
    plt.close(fig)

    # This figure previously requested y='Norm Weight' from sumData2017,
    # which never receives that column. It now plots the MilesDriven column
    # to match the output file name.
    # NOTE(review): confirm that weighted-mean miles is the intended quantity.
    fig, ax = plt.subplots(1,1,figsize=(x,y))
    sns.barplot(data=sumData2017, x='NumVehicles', y='MilesDriven', hue='Safety Inspection', ax=ax)
    ax.set_ylabel('Mean Household Miles Driven')
    ax.set_xlabel('Number of Household Vehicle')
    fig.savefig(f'Plots/MilesDriven-{year}.png', bbox_inches='tight', dpi=300)
    plt.close(fig)

    # print(list(hhVehData2017.columns))
    # r = {}
    #
    # for hhID in set(vehData2022.loc[:, 'HOUSEID']):
    #     tempData = vehData2022.loc[vehData2022.loc[:, 'HOUSEID']==hhID, 'VEHAGE']
    #     r[hhID] = {'min':np.min(tempData), 'mean':np.mean(tempData), 'max':np.max(tempData)}
    #
    # r = p.DataFrame.from_dict(r, orient='index')
    # r.loc[:, 'HOUSEID'] = r.index